import time

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from load import get_gnn_inputs
from load_local_refinement import get_gnn_inputs_local_refinement
from losses import compute_loss_multiclass, compute_accuracy_multiclass
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Fixed-width, left-aligned layouts for the per-iteration training log:
# one header line of four column titles, and one row holding an integer
# followed by three floats (4, 2 and 2 decimals respectively).
template_header = '{:<6} {:<10} {:<10} {:<10}'
template_row = '{:<6d} {:<10.4f} {:<10.2f} {:<10.2f}'

# Default compute device: CUDA when PyTorch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level caches; nothing in the visible part of this file reads or
# writes them (NOTE(review): confirm whether other modules still use them).
cached_graphs, cached_labels = [], []

def from_scores_to_labels_multiclass_batch(pred):
    """Turn batched per-class scores into hard integer labels.

    Args:
        pred: Array-like with at least three axes; the class scores are
            expected along axis 2 (callers in this file pass an array
            they describe as [B, N, n_classes]).

    Returns:
        np.ndarray of ints holding, for every position along the first two
        axes, the index of the largest score along axis 2. The original
        author noted a typical result shape of (1, 1000).
    """
    winning_class = np.argmax(pred, axis=2)
    return winning_class.astype(int)

def get_start_labels_from_first_period(gnn_first_period, Ws, J, device):
    """Generate initial hard labels for a batch of graphs with the first-period model.

    Args:
        gnn_first_period: First-period GNN, invoked as ``gnn_first_period(WW, x)``.
        Ws: torch.Tensor of batched adjacency matrices, documented by the
            original author as [B, N, N]; it may live on CPU or GPU.
        J: int forwarded unchanged to ``get_gnn_inputs`` (described upstream
            as the number of powers).
        device: torch.device on which the forward pass runs.

    Returns:
        np.ndarray of integer labels (documented upstream as [B, N]),
        obtained by an argmax over the class scores.
    """
    # Build the first-period inputs on the host; the helper receives NumPy data.
    adjacency_np = Ws.detach().cpu().numpy()
    operators_np, features_np = get_gnn_inputs(adjacency_np, J)  # upstream note: (B, N, N, J+3), (B, N, d)

    operators = torch.as_tensor(operators_np, dtype=torch.float32, device=device)
    features = torch.as_tensor(features_np, dtype=torch.float32, device=device)

    # NOTE(review): the model is switched to train mode (not eval) even though
    # this is pure inference; layers such as dropout or batch norm behave
    # differently in that mode. Kept as-is -- confirm that this is intended.
    gnn_first_period.train()

    # Only the labels are needed here, so no autograd graph is recorded.
    with torch.no_grad():
        scores = gnn_first_period(operators, features)  # upstream note: [B, N, n_classes]

    # The argmax is taken on NumPy data, through the shared helper.
    scores_np = scores.detach().cpu().numpy()
    return from_scores_to_labels_multiclass_batch(scores_np)


## Train the second-period GNN (gnn_second_period) on one batch; the first-period model only supplies the initial labels
def train_batch_second_period(gnn_first_period, gnn_second_period, optimizer, batch, n_classes, iter, device, args):
    """Run one optimisation step of ``gnn_second_period`` on a single batch.

    The first-period model supplies initial labels, which are turned into
    local-refinement inputs for the second-period model. The loss comes from
    ``compute_loss_multiclass`` (described by the original author as a
    user-defined, permutation-aware loss).

    Args:
        gnn_first_period: First-period GNN, used only to produce initial labels.
        gnn_second_period: Model being trained; its gradients are clipped
            to ``args.clip_grad_norm``.
        optimizer: Optimizer that is zeroed and stepped once.
        batch: Mapping with tensors under ``'adj'`` and ``'labels'``
            (upstream notes: (B, N, N) and (B, N)).
        n_classes: Number of classes passed to the loss and accuracy helpers.
        iter: Iteration index, used only in the printed log row.
        device: torch.device for all tensors of this step.
        args: Namespace providing ``J``, ``J_second``, ``n_classes`` and
            ``clip_grad_norm``.

    Returns:
        Tuple ``(loss_value, acc)``: the scalar loss as a Python float and
        the accuracy reported by ``compute_accuracy_multiclass``.
    """
    # Both networks are put in train mode, although only the second one is
    # updated in this function.
    for model in (gnn_first_period, gnn_second_period):
        model.train(True)

    adjacency = batch['adj'].to(device)   # upstream note: (B, N, N)
    targets = batch['labels'].to(device)  # upstream note: (B, N)

    tic = time.time()

    # Step 1: initial labels from the first-period model.
    initial_labels = get_start_labels_from_first_period(gnn_first_period, adjacency, args.J, device)

    # Step 2: second-period inputs, built on the host from NumPy data.
    # NOTE(review): this call uses args.n_classes while the loss/accuracy
    # below use the n_classes parameter -- confirm they are always equal.
    adjacency_np = adjacency.detach().cpu().numpy()
    operators_np, features_np = get_gnn_inputs_local_refinement(
        adjacency_np, args.J_second, initial_labels, args.n_classes
    )
    operators = torch.as_tensor(operators_np, dtype=torch.float32, device=device)
    features = torch.as_tensor(features_np, dtype=torch.float32, device=device)

    optimizer.zero_grad(set_to_none=True)

    # Step 3: forward and backward pass of the second-period model.
    scores = gnn_second_period(operators, features)
    batch_loss = compute_loss_multiclass(scores, targets, n_classes)
    batch_loss.backward()

    nn.utils.clip_grad_norm_(gnn_second_period.parameters(), args.clip_grad_norm)
    optimizer.step()

    # Step 4: statistics (accuracy is computed from the pre-step predictions).
    accuracy, _ = compute_accuracy_multiclass(scores, targets, n_classes)
    elapsed = time.time() - tic
    loss_scalar = batch_loss.item()

    print(template_header.format('iter', 'avg loss', 'avg acc', 'elapsed'))
    print(template_row.format(iter, loss_scalar, accuracy, elapsed))

    return loss_scalar, accuracy

def evaluate_on_loader(gnn_first_period, gnn_second_period, val_loader, n_classes, args, device):
    """Evaluate the two-stage pipeline on every batch yielded by ``val_loader``.

    For each batch the first-period model produces initial labels, the
    second-period model refines them, and loss, accuracy, ARI and NMI are
    accumulated. No gradients are recorded.

    Args:
        gnn_first_period: First-period GNN supplying the initial labels.
        gnn_second_period: Second-period GNN being evaluated.
        val_loader: Iterable of mappings with tensors under ``'adj'`` and
            ``'labels'``.
        n_classes: Number of classes passed to the loss and accuracy helpers.
        args: Namespace providing ``J``, ``J_second`` and ``n_classes``.
        device: torch.device used for all tensors.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_ari, avg_nmi)`` averaged over the
        batches. Note the order: ARI comes before NMI.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    # NOTE(review): both models are put in train mode for evaluation, as in
    # the original code and the other functions of this file. Kept unchanged;
    # confirm that train-mode inference is intended.
    gnn_first_period.train()
    gnn_second_period.train()

    total_loss, total_acc, total_nmi, total_ari = 0, 0, 0, 0
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            Ws = batch['adj'].to(device)
            labels = batch['labels'].to(device)

            start_x_batch = get_start_labels_from_first_period(gnn_first_period, Ws, args.J, device)  # [B, N]

            # Same conversion as the training path: torch.as_tensor accepts
            # NumPy arrays as well as tensors, whereas the former
            # `.clone().detach()` call only worked on tensors.
            WW_np, x_np = get_gnn_inputs_local_refinement(
                Ws.detach().cpu().numpy(), args.J_second, start_x_batch, args.n_classes
            )
            WW = torch.as_tensor(WW_np, dtype=torch.float32, device=device)
            x = torch.as_tensor(x_np, dtype=torch.float32, device=device)

            pred = gnn_second_period(WW, x)
            loss = compute_loss_multiclass(pred, labels, n_classes)
            acc, _ = compute_accuracy_multiclass(pred, labels, n_classes)

            pred_labels = from_scores_to_labels_multiclass_batch(pred.detach().cpu().numpy())  # [B, N]
            true_labels = labels.detach().cpu().numpy().reshape(pred_labels.shape[0], -1)

            # Score each graph on its own and average: cluster ids are only
            # meaningful within one graph, so pooling several graphs into a
            # single label vector would distort ARI/NMI. For a batch of one
            # graph this equals the previous flattened computation.
            ari = float(np.mean([adjusted_rand_score(t, p)
                                 for t, p in zip(true_labels, pred_labels)]))
            nmi = float(np.mean([normalized_mutual_info_score(t, p)
                                 for t, p in zip(true_labels, pred_labels)]))

            total_loss += loss.item()
            total_acc += acc
            total_nmi += nmi
            total_ari += ari
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate_on_loader: val_loader yielded no batches")

    avg_loss = total_loss / n_batches
    avg_acc = total_acc / n_batches
    avg_nmi = total_nmi / n_batches
    avg_ari = total_ari / n_batches
    return avg_loss, avg_acc, avg_ari, avg_nmi


def train_second_period_with_early_stopping(
        gnn_first_period, gnn_second_period, train_loader, val_loader, n_classes, args,
        epochs=100, patience=6, save_path='best_model.pt', filename="filename_first",
        acc_eps: float = 1e-8,
        loss_eps: float = 1e-12,
):
    """Train the second-period GNN with NMI-based early stopping.

    After every epoch the model is evaluated on ``val_loader``. A new best
    model is recorded when validation NMI improves by more than ``acc_eps``
    or, on an NMI tie, when validation loss drops by more than ``loss_eps``.
    Training stops after ``patience`` consecutive epochs without improvement.

    Args:
        gnn_first_period: First-period GNN; it is not given to the optimizer.
        gnn_second_period: Model being trained (Adamax, ``lr=args.lr``).
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        n_classes: Number of classes forwarded to the batch/eval helpers.
        args: Namespace forwarded to the helpers; ``args.lr`` is read here.
        epochs: Maximum number of epochs.
        patience: Number of non-improving epochs tolerated before stopping.
        save_path: File that receives the best model.
        filename: File that receives a snapshot after every epoch.
        acc_eps: Tolerance used when comparing validation NMI (the name is
            historical).
        loss_eps: Tolerance used for the validation-loss tie-break.

    Returns:
        Tuple ``(loss_lst, acc_lst)`` with the per-iteration training loss
        and accuracy over all epochs that were run.
    """
    gnn_first_period.train()
    gnn_second_period.train()

    optimizer = torch.optim.Adamax(gnn_second_period.parameters(), lr=args.lr)

    loss_lst, acc_lst = [], []
    best_val_nmi = -1.0
    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        gnn_second_period.train()

        for iter_idx, batch in enumerate(tqdm(train_loader)):
            loss, acc = train_batch_second_period(
                gnn_first_period=gnn_first_period,
                gnn_second_period=gnn_second_period,
                optimizer=optimizer,
                batch=batch,
                n_classes=n_classes,
                iter=iter_idx,
                device=device,
                args=args
            )

            loss_lst.append(loss)
            acc_lst.append(acc)
            torch.cuda.empty_cache()

        # Validation. evaluate_on_loader returns (loss, acc, ARI, NMI) in
        # that order; unpacking NMI and ARI the other way round would make
        # model selection and the log line below use ARI under the name NMI.
        val_loss, val_acc, val_ari, val_nmi = evaluate_on_loader(
            gnn_first_period, gnn_second_period, val_loader, n_classes, args,
            device=device)
        print(f"Validation Loss: {val_loss:.6f}, NMI: {val_nmi:.6f},  Accuracy: {val_acc:.6f}")

        # Per-epoch snapshot: the whole module is saved from the CPU, then
        # moved back to the training device.
        torch.save(gnn_second_period.cpu(), filename)
        if torch.cuda.is_available():
            gnn_second_period = gnn_second_period.to(device)

        # Best-model rule: compare NMI first; on an NMI tie, compare loss.
        if val_nmi > best_val_nmi + acc_eps:
            reason = "val_nmi improved"
        elif abs(val_nmi - best_val_nmi) <= acc_eps and val_loss < best_val_loss - loss_eps:
            reason = "val_nmi tie, val_loss improved"
        else:
            reason = None
        improved = reason is not None

        if improved:
            best_val_nmi = val_nmi
            best_val_loss = val_loss
            patience_counter = 0

            # NOTE(review): this pickles the whole module object; saving
            # gnn_second_period.state_dict() is the more portable option,
            # but would change what loaders of `save_path` receive.
            torch.save(gnn_second_period.cpu(), save_path)

            print(f"New best model saved ({reason}). best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")
            if torch.cuda.is_available():
                gnn_second_period = gnn_second_period.to(device)
        else:
            patience_counter += 1
            print(f"No improvement ({patience_counter}/{patience}). "
                  f"best_nmi={best_val_nmi:.6f}, best_loss={best_val_loss:.6f}")

        # Early stopping.
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

        torch.cuda.empty_cache()  # optional: release cached GPU memory once per epoch

    return loss_lst, acc_lst
